import os
import sys
sys.path.insert(0, os.getcwd())

import wandb
from analysis.base import Gen_Analysis, np_softmax
import torch
import numpy as np
import seaborn as sns
import torch.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import os
import csv
from scipy import linalg
from tqdm import tqdm


#sns.set(style="whitegrid")


class ISAnalysis(Gen_Analysis):
    """Inception-Score analysis over the outputs of ``self.model``.

    Outputs collected by :meth:`_collect_data` are cached under
    ``self.results_exp[desc]["probs"]`` and scored with
    :meth:`calculate_inception_score`. ``self.model``, ``self.results_all``,
    ``self.rootdir`` and ``self._data_to_loader`` are presumably provided by
    ``Gen_Analysis`` (not visible here — TODO confirm).
    """

    def __init__(self, rootdir, args):
        super().__init__(original=None, rootdir=rootdir,
                         args=args)

        # Number of chunks the samples are split into to get a mean/std of the
        # score; overridable through an optional ``args.splits``.
        self.splits = getattr(args, "splits", 10)
        self.results_exp = self.results_all[args.exp]

    @torch.no_grad()
    def _collect_data(self, model, loader, desc=None):
        """Run ``model`` over every batch of ``loader`` and stack the outputs.

        Args:
            model: module whose call result is indexed with ``[-1]``
                (presumably the last of several returned heads — TODO confirm).
                It is switched to eval mode and left there.
            loader: iterable of batches. A list/tuple batch such as
                ``(images, labels)`` contributes only its first element.
            desc: optional tqdm progress-bar label.

        Returns:
            ``{"probs": array}`` with the outputs concatenated along axis 0.
            No softmax is applied here, so whether these are probabilities
            depends on ``model``.
        """
        model.eval()
        batches = []
        for x in tqdm(loader, desc="Collecting features for Analysis" if desc is None else desc):
            if isinstance(x, (list, tuple)):
                x = x[0]  # drop labels / any extra fields

            x = x.cuda(non_blocking=True)
            pred = model(x)[-1]
            batches.append(pred.detach().cpu().numpy())

        return {"probs": np.concatenate(batches, axis=0)}

    def _compute_staistics(self, activation):
        """Return the per-feature mean and covariance of ``activation``.

        Rows are samples and columns are features (``rowvar=False``). The
        misspelled name is kept for compatibility with existing callers.
        """
        mu = np.mean(activation, axis=0)
        sigma = np.cov(activation, rowvar=False)

        return mu, sigma

    def calculate_inception_score(self, probs):
        """Compute the Inception Score of ``probs`` over ``self.splits`` chunks.

        For each contiguous chunk the score is
        ``exp(mean_x KL(p(y|x) || p(y)))``, where ``p(y)`` is the chunk's mean
        row. Rows are expected to already be probability distributions; they
        are not normalised here.

        Args:
            probs: 2-D array of shape (num_samples, num_classes).

        Returns:
            ``(mean, std)`` of the per-chunk scores.

        Raises:
            ValueError: if ``probs`` contains no samples.
        """
        probs = np.asarray(probs)
        n = probs.shape[0]
        if n == 0:
            raise ValueError("calculate_inception_score needs at least one sample")
        # Never use more chunks than samples: an empty chunk has a NaN mean,
        # which would poison the whole score.
        splits = max(1, min(self.splits, n))

        scores = []
        for i in range(splits):
            part = probs[(i * n // splits):((i + 1) * n // splits), :]
            marginal = np.mean(part, axis=0, keepdims=True)
            # 0 * log(0) is NaN in floating point but 0 by the KL convention;
            # softmax outputs underflow to exact zeros, so mask those terms.
            with np.errstate(divide="ignore", invalid="ignore"):
                kl = np.where(part > 0,
                              part * (np.log(part) - np.log(marginal)),
                              0.0)
            scores.append(np.exp(np.mean(np.sum(kl, axis=1))))

        inception_score, std = np.mean(scores), np.std(scores)
        return inception_score, std

    def get_xys(self, data, desc="generated"):
        """Collect model outputs for ``data`` and return their Inception Score.

        Args:
            data: a tuple (converted with ``self._data_to_loader``) or an
                iterable of batches used as-is.
            desc: key under which the outputs are cached in
                ``self.results_exp`` (later reported by :meth:`plot`).

        Returns:
            ``(score, std)`` from :meth:`calculate_inception_score`.
        """
        dloader = self._data_to_loader(data) if isinstance(data, tuple) else data

        self.results_exp[desc] = self._collect_data(model=self.model, loader=dloader,
                                                    desc=desc)

        return self.calculate_inception_score(self.results_exp[desc]["probs"])

    def plot(self, rootdir=None):
        """Print the Inception Score of every cached result.

        ``rootdir`` is accepted for interface compatibility and is unused.
        """
        for key, result in self.results_exp.items():
            is_score, std = self.calculate_inception_score(result["probs"])
            print("IS Score {}: {}/{}".format(key, is_score, std))

    def _file_name(self, postfix, rootdir=None):
        """Build an output path ``<rootdir or self.rootdir>/fid <postfix>``."""
        # NOTE(review): the "fid" prefix looks copied from the FID analysis;
        # kept unchanged because existing output paths may depend on it.
        return os.path.join(self.rootdir if rootdir is None else rootdir, "fid {}".format(postfix))

    def print(self,):
        """No-op: this analysis has nothing to print beyond :meth:`plot`."""
        return  # nothing to print

    def to_csv(self, rootdir=None):
        """No-op: this analysis writes no CSV output."""
        return  # nothing to write

if __name__ == '__main__':
    # Smoke test: compute the Inception Score of the CIFAR-10 training images.
    from datatool.datatool import get_dl_tr, data_path
    from torchvision import datasets
    from torch.utils.data import DataLoader
    import argparse

    from torchvision import transforms
    from tools.utils import init_distributed_mode

    parser = argparse.ArgumentParser()
    parser.add_argument('--num-workers', type=int, default=0,
                        help='number of DataLoader worker processes')
    parser.add_argument('--bsz', type=int, default=6,
                        help='batch_size_per_gpu')
    parser.add_argument('--debug', type=int, default=1,
                        help='debug or not')
    parser.add_argument('--x_sigma', type=float, default=0.001,
                        help='x_sigma')
    parser.add_argument('--exp', type=str, default="test",
                        help='experiment key selecting the results entry (e.g. CIFAR10 ori)')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
            distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()

    dt_cifar = datasets.CIFAR10(root=data_path["cifar10"] + 'train/', train=True,
                                download=True, transform=transforms.Compose([transforms.ToTensor()]))
    # NOTE: SVHN is only used by the commented-out evaluation below, yet it is
    # still downloaded/loaded here.
    dt_svhn = datasets.SVHN(root=data_path["svhn"] + 'train/', split="train",
                            download=True, transform=transforms.Compose([transforms.ToTensor()]))

    dl_cifar = DataLoader(
        dt_cifar,
        batch_size=args.bsz,
        num_workers=args.num_workers,
    )
    # dl_svhn = DataLoader(
    #     dt_svhn,
    #     batch_size=args.bsz,
    #     num_workers=args.num_workers,
    # )

    is_anl = ISAnalysis(rootdir="./out", args=args)
    # is_anl.get_xys((dt_svhn.data[:200],), desc="svhn2")
    is_anl.get_xys(dl_cifar, desc="cifar2")  # returns (score, std) directly; reported again by plot()

    is_anl.plot()
